
import os
import itertools

import matplotlib as mpl
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.io import loadmat
import scipy.stats as stats
from statsmodels.stats.anova import AnovaRM
from statsmodels.stats.libqsturng import psturng
from statsmodels.stats.multicomp import pairwise_tukeyhsd
#from statsmodels.stats.proportion import proportions_chisquare_pairscontrol as dunnett ## Don't now how to use

import annotations
#from courtana.jaaba import read_behavior_predictions


# This "_mod" version has only one change: the "calc_detection_accuracy" function returns an extra value, namely a list
# of start frames, one per detected aggressive encounter. This list is required for the
# "aggressive_encounters_timing_during_copulation" histogram.
def calc_detection_accuracy(aggression_timepoints, detected_encounters, threshold=0.20, use_thresh=False):
    '''
    Calculates which fraction of annotated aggression events is overlapped by a detected event.
    Takes the following arguments:
    
    aggression_timepoints:   (list); a list of (start, end) tuples representing annotated aggression events.
    detected_encounters:     (list); a list of (start, end) tuples representing events created through our thresholds.
    threshold:               (float); minimum fraction of an aggression event's frames that must be covered by a
                                      detected event in order to consider them really overlapping.
    use_thresh:              (bool); whether to use the threshold for detecting overlap, or accept any overlap.
    
    Events are treated as half-open frame ranges: frames start, start + 1, ..., end - 1.
    The search for a matching encounter resumes from the last matched encounter, so both lists are
    presumably expected to be sorted by start frame (TODO confirm against callers).
    
    Returns a tuple containing: the number of detected aggression events, the total number of aggression events,
    the accuracy score (detected / total; NaN when there are no aggression events), the overlap fraction of each
    detected aggression event, and the start frame of the encounter that matched each detected aggression event.
    '''
    
    n_aggression_events = len(aggression_timepoints)
    
    detected_aggression = []
    intersections = []
    aggressive_encounters_start_frame = []
    encounter_index = 0
    for timepoint in aggression_timepoints:
        aggression_start, aggression_end = timepoint
        aggression_length = max(0, aggression_end - aggression_start)
        
        # A zero-length event can never be overlapped, so there is nothing to search for.
        if aggression_length == 0:
            continue
        
        for e, encounter in enumerate(detected_encounters[encounter_index:]):
            encounter_start, encounter_end = encounter
            
            # Number of frames shared by the two half-open ranges (0 when they are disjoint).
            overlap = max(0, min(aggression_end, encounter_end) - max(aggression_start, encounter_start))
            fraction_intersection = overlap / aggression_length
            
            is_match = fraction_intersection >= threshold if use_thresh else overlap > 0
            if is_match:
                detected_aggression.append(timepoint)
                intersections.append(fraction_intersection)
                aggressive_encounters_start_frame.append(encounter_start)
                # Resume the next search at this encounter (not after it): one long encounter
                # may overlap several consecutive aggression events.
                encounter_index += e
                break
    
    n_detected_aggression = len(detected_aggression)
    # The accuracy is undefined without annotated events; report NaN instead of dividing by zero.
    fraction_detected = n_detected_aggression / n_aggression_events if n_aggression_events else np.nan
    
    print('Number of aggressive bouts is:', n_aggression_events)
    print('Number of clean detected "encounters" is:', len(detected_encounters))
    print('Number of detected aggressive bouts is:', n_detected_aggression)
    
    return (n_detected_aggression, n_aggression_events, fraction_detected, intersections, aggressive_encounters_start_frame)


def _outlier_positions(values, lower_bound, upper_bound):
    '''
    Returns the integer positions of entries below lower_bound, followed by those above upper_bound.
    Within each group, positions are ordered by value and then by position.
    '''
    positions = []
    for mask in (values < lower_bound, values > upper_bound):
        candidates = np.flatnonzero(mask)
        # A stable sort keeps equal values in ascending position order.
        order = np.argsort(values[candidates], kind='stable')
        positions.extend(int(position) for position in candidates[order])
    return positions


def check_outliers(data, method='both', severity=3):
    '''
    Uses the IQR and z-score methods to determine which data points are considered outliers, and returns their index labels.
    Takes the following arguments:
    
    data:      (pandas.Series, numpy array, or list); the dataset to be checked for outliers. Non-Series input is
                      wrapped in a Series with a default positional index, so its labels are positions.
    method:    (str); which detection method to apply: "iqr" or "zscore". IQR is less restrictive, while z-score is more restrictive.
                      Any other value (default "both") returns only those outliers detected by both methods. This is the more conservative approach.
    severity:  (int); how many standard deviations to use when calculating outliers with the z-score method.
                   
    Returns a list of index labels. For "iqr" and "zscore", low outliers come first, then high outliers, each ordered
    by value; for "both", the labels are sorted. Empty input returns an empty list.
    '''
    
    if not isinstance(data, pd.Series):
        data = pd.Series(data)
    data_val = data.to_numpy()
    if data_val.size == 0:
        return []

    # Inter-Quartile Range method: flag points beyond 1.5 * IQR outside the first/third quartiles.
    q1 = np.quantile(data_val, q=0.25)
    q3 = np.quantile(data_val, q=0.75)
    iqr = stats.iqr(data_val)
    iqr_positions = _outlier_positions(data_val, q1 - 1.5 * iqr, q3 + 1.5 * iqr)
    iqr_outliers_index_labels = list(data.iloc[iqr_positions].index)
    
    # z-score method: flag points more than `severity` standard deviations from the mean.
    zscore = np.asarray(stats.zscore(data_val))
    zscore_positions = _outlier_positions(zscore, -severity, severity)
    zscore_outliers_index_labels = list(data.iloc[zscore_positions].index)
    
    if method == 'iqr':
        return iqr_outliers_index_labels
    if method == 'zscore':
        return zscore_outliers_index_labels
    return list(np.intersect1d(iqr_outliers_index_labels, zscore_outliers_index_labels))


def detect_events(series):
    '''
    Get start and end frames for each event in a pandas Series.
    Takes the following arguments:
    
    series   (pandas.Series); Series from which to detect event starts and ends. Values are cast to int;
                              an event is a run of frames whose value is 1 (True).

    Returns two numpy arrays of positions: one with the first frame of each event, and another with
    the last frame of each event (both inclusive). An empty Series yields two empty arrays.
    '''

    # Read predictions by converting them to a numpy array of ints.
    scores = np.array(series.astype(int))
    if scores.size == 0:
        return np.array([], dtype=np.intp), np.array([], dtype=np.intp)
    
    # A false-to-true transition between frames i and i + 1 means the event starts at i + 1;
    # a true-to-false transition between i and i + 1 means the event's last frame is i.
    behavior_on = np.flatnonzero(scores[:-1] < scores[1:]) + 1
    behavior_off = np.flatnonzero(scores[:-1] > scores[1:])

    # Detect events that start at the beginning of the video.
    if scores[0] == 1:
        behavior_on = np.insert(behavior_on, 0, 0)
    
    # Detect events that end at the end of the video.
    if scores[-1] == 1:
        behavior_off = np.append(behavior_off, len(scores) - 1)
    
    # Internal invariant: every event that starts must also end.
    assert behavior_on.size == behavior_off.size
        
    return behavior_on, behavior_off


def draw_components(params, number, axis):
    '''
    Draws the geometric construction computed for one bout onto one subplot.
    Takes the following arguments:
    
    params  (dict); parameter dictionary for the bout. It holds 'aggressive' and 'mating' sub-dictionaries
                    (each with 'ellipse', 'centroid', 'head', 'tail', 'major_x', 'major_y', 'left_flank',
                    'right_flank', 'minor_x', 'minor_y'), plus 'intersect_coords', 'ellipse_point',
                    'body_intersection_point' and 'mating_arc'.
    number  (int); iteration number. Selects the subplot axis[number] and sets the title "Bout #<number + 1>".
    axis    (sequence of matplotlib Axes); the axes collection on which to draw.
    
    Has no return value. The view is centred on the mating fly's centroid, with a half-width of 2.0 in each direction.
    '''

    ax = axis[number]
    mating = params['mating']

    for behavior in ('aggressive', 'mating'):
        fly = params[behavior]
        centroid = fly['centroid']
        head = fly['head']
        tail = fly['tail']
        left = fly['left_flank']
        right = fly['right_flank']

        # Body ellipse, centroid (red dot), head (red cross) and tail (black cross).
        ax.add_patch(fly['ellipse'])
        ax.plot(centroid[0], centroid[1], 'ro', markersize=5)
        ax.plot(head[0], head[1], 'rx', markersize=8, markeredgewidth=3)
        ax.plot(tail[0], tail[1], 'kx', markersize=8, markeredgewidth=3)

        # Major-axis segment, then the extended major-axis line.
        ax.plot([head[0], tail[0]], [head[1], tail[1]], 'red', linewidth=2)
        ax.plot(fly['major_x'], fly['major_y'], 'g')

        # Flank markers, minor-axis segment, then the extended minor-axis line.
        ax.plot(left[0], left[1], 'kx', markersize=8)
        ax.plot(right[0], right[1], 'kx', markersize=8)
        ax.plot([left[0], right[0]], [left[1], right[1]], 'red', linewidth=2)
        ax.plot(fly['minor_x'], fly['minor_y'], 'g')

    crossing = params['intersect_coords']
    body_point = params['ellipse_point']
    axis_point = params['body_intersection_point']
    origin = mating['centroid']

    # Crossing of the two major-axis lines, and the point where the aggressor's line meets the mating ellipse.
    ax.plot(crossing[0], crossing[1], 'ko', markersize=2)
    ax.plot(body_point[0], body_point[1], 'rx', markersize=5)

    # Triangle: mating centroid -> body point, projection on the major axis, centroid -> projection, body point -> projection.
    ax.plot((origin[0], body_point[0]), (origin[1], body_point[1]), 'k')
    ax.plot(axis_point[0], axis_point[1], 'rx')
    ax.plot((origin[0], axis_point[0]), (origin[1], axis_point[1]), 'k')
    ax.plot((body_point[0], axis_point[0]), (body_point[1], axis_point[1]), 'k')

    # Arc visualising the computed body-point angle.
    ax.add_patch(params['mating_arc'])

    ax.set_title(f'Bout #{number + 1}', y=0.85)

    half_window = 2.0
    ax.set_xlim(origin[0] - half_window, origin[0] + half_window)
    ax.set_ylim(origin[1] - half_window, origin[1] + half_window)


def events_to_annotation(timepoints, behaviors, experiment, savepath=None):
    '''
    For each behavior, takes a list of tuples with start and end frames for each event,
    and converts them into an Annotation file readable by PythonVideoAnnotator.
    Takes the following arguments:
    
    timepoints:   (list); one list of (start, end) tuples per behavior; timepoints[b] holds the events of behaviors[b].
    behaviors:    (list); list of strings representing the behaviors we want to add to file.
                          This will also set the name of the Tracks in PythonVideoAnnotator.
                          Must be the same length as 'timepoints'.
    experiment:   (str); name for our saved file (".csv" is appended).
    savepath:     (str or None); directory in which to save our file. Defaults to the current working
                                 directory at call time.

    An existing file with the same name is overwritten. Spaces in the written lines (including spaces
    inside behavior names) are turned into commas.

    Returns no output variable.
    '''

    # Resolve the default here: os.getcwd() as a signature default would be evaluated once, at import time.
    if savepath is None:
        savepath = os.getcwd()
    filename = os.path.join(savepath, experiment + '.csv')

    # Dark2 palette. matplotlib >= 3.5 exposes the `colormaps` registry; `cm.get_cmap` was deprecated
    # in 3.7 and removed in 3.9, so only fall back to it on older versions.
    colormap_registry = getattr(mpl, 'colormaps', None)
    if colormap_registry is not None:
        color_set = colormap_registry['Dark2']
    else:
        color_set = mpl.cm.get_cmap(name='Dark2', lut=None)
    hex_color_set = [mpl.colors.to_hex(color) for color in color_set.colors]

    # Colors are assigned by track position: track 0 gets its own color, then each following
    # group of four tracks shares one color (group names follow the variables below).
    raw_color_set = [hex_color_set[1]] * 4
    clean_color_set = [hex_color_set[2]] * 4
    clean_vel_peaks_raw_color_set = [hex_color_set[3]] * 4
    clean_vel_peaks_rolling_window3_color_set = [hex_color_set[4]] * 4
    complete_color_set = [hex_color_set[0], *raw_color_set, *clean_color_set, *clean_vel_peaks_raw_color_set, *clean_vel_peaks_rolling_window3_color_set]

    # Write every track in a single pass; mode 'w' replaces any previous version of the file.
    with open(filename, 'w') as file:
        for b, behavior in enumerate(behaviors):
            # Cycle through the colors rather than fail when there are more tracks than colors.
            color = complete_color_set[b % len(complete_color_set)]

            track = 'T ' + behavior + ' ' + color + '\n'
            file.write(track.replace(' ', ','))

            for t in timepoints[b]:
                line = 'P False ' + str(t[0]) + ' ' + str(t[1]) + ',,' + color + ' ' + str(b) + '\n'
                file.write(line.replace(' ', ','))


def fix_orientation_flips(series, tolerance=15, extra_data=False):
    '''
    Fixes random orientation flips from FlyTracker tracking output.
    Takes the following arguments:
    
    series:      (pandas.Series); pandas Series with fly orientation data (in degrees) to correct.
                                  Consecutive rows are compared by position, so any index is supported.
    tolerance:   (int); angle (in degrees) to allow for variation in the orientation flip.
    extra_data:  (bool); whether or not to return additional data for inspection.
    
    A jump between consecutive rows whose magnitude lies in [180 - tolerance, 180 + tolerance] is treated
    as a flip, and the later value is shifted by 180 degrees towards the earlier one. Corrections are applied
    sequentially, so each comparison uses the already-corrected previous value.
    
    Returns a float pandas Series (same index and name) with flipped orientations corrected. If extra_data
    is True, returns a tuple (corrected series, index labels of the row preceding each flip, raw differences).
    The input Series is not modified.
    '''
    
    values = series.to_numpy(dtype=float, copy=True)
    labels = series.index
    lower_limit = 180 - tolerance
    upper_limit = 180 + tolerance

    index_flips = []
    ori_diffs = []
    for position in range(1, len(values)):
        ori_diff = values[position] - values[position - 1]
        if lower_limit <= abs(ori_diff) <= upper_limit:
            index_flips.append(labels[position - 1])
            ori_diffs.append(ori_diff)
            values[position] -= np.sign(ori_diff) * 180

    fixed_series = pd.Series(values, index=series.index, name=series.name)
    
    if extra_data:
        return (fixed_series, index_flips, ori_diffs)
        
    return fixed_series


def get_body_point(point_a, point_b, fly_parameters, precision=100):
    '''
    Approximates, by bisection, the point where the segment from point_a to point_b crosses a fly's ellipse.
    Takes the following arguments:
    
    point_a         (tuple); (x, y) coordinates of the first point in the line (expected outside the ellipse).
    point_b         (tuple); (x, y) coordinates of the second point in the line (expected inside the ellipse).
    fly_parameters  (dict); dictionary of fly ellipse parameters, as accepted by is_in_ellipse.
    precision       (int); number of bisection iterations to perform (must be at least 1). The higher the number,
                           the more precise the computed (x, y) coordinates for the intersection point.
    
    Each iteration moves point_b to the midpoint if the midpoint is inside the ellipse, otherwise point_a,
    so the result converges to the boundary when point_a is outside and point_b is inside.
    
    Returns a tuple of (x, y) coordinates for the last computed midpoint.
    Raises ValueError if precision is smaller than 1.
    Source: https://stackoverflow.com/questions/53705803/calculating-an-intercept-point-between-a-straight-line-and-an-ellipse-python
    '''
    
    # Without at least one iteration there is no midpoint to return.
    if precision < 1:
        raise ValueError('precision must be at least 1.')
    
    for _ in range(precision):
        center_distance = ((point_b[0] - point_a[0]) / 2, (point_b[1] - point_a[1]) / 2)
        center_point = (point_a[0] + center_distance[0], point_a[1] + center_distance[1])
        if is_in_ellipse(center_point, fly_parameters):
            point_b = center_point
        else:
            point_a = center_point
            
    return center_point


def _report_effect_size(effect):
    '''
    Prints a qualitative label for an effect size, using |effect| thresholds of 0.20, 0.50 and 0.80.
    Prints nothing when the effect size is NaN.
    '''
    magnitude = abs(effect)
    if magnitude < 0.20:
        print('Vestigial Effect:', effect, '\n')
    elif magnitude < 0.50:
        print('Small Effect:', effect, '\n')
    elif magnitude < 0.80:
        print('Medium Effect:', effect, '\n')
    elif magnitude >= 0.80:
        print('Large Efect:', effect, '\n')


def get_effect_size(control_data, test_data, method='hedge_g'):
    '''
    Calculates the effect size of a test condition relative to a control group.
    Takes the following arguments:
    
    control_data:  (pandas.Series, numpy array, or list); control group.
    test_data:     (pandas.Series, numpy array, or list); test group.
    method:        (str); "hedge_g" or "median_diff".
                          "hedge_g": if test_variance_homogeneity reports equal variances, Hedges' g*
                          (Hedges' g with the small-sample correction 1 - 3 / (4 * (n1 + n2) - 9)) is used;
                          otherwise Glass's delta (mean difference divided by the control group's sample SD).
                          "median_diff": the median fold change,
                          (median(test) - median(control)) / median(control).
    
    Prints a qualitative label for the effect and returns its value as a float.
    Raises ValueError for an unknown method.
    
    Sources: https://garstats.wordpress.com/2016/05/02/robust-effect-sizes-for-2-independent-groups/
             http://onlinestatbook.com/2/estimation/difference_means.html
             https://www.statisticshowto.com/hedges-g/
             https://www.statisticshowto.com/glasss-delta/
             https://www.statisticshowto.com/pooled-standard-deviation/
             https://en.wikipedia.org/wiki/Effect_size
    '''
    
    if method == 'hedge_g':
        
        control_size = len(control_data)
        test_size = len(test_data)
        sample_size = control_size + test_size

        # Unbiased sample variances: ddof=1 divides the squared deviations by n - 1.
        control_variance = np.var(control_data, ddof=1)
        test_variance = np.var(test_data, ddof=1)

        # Pooled standard deviation of the two samples.
        pooled_standard_deviation = np.sqrt(((control_size - 1) * control_variance + (test_size - 1) * test_variance) / (sample_size - 2))

        mean_difference = np.mean(test_data) - np.mean(control_data)

        combined_data = {'control': control_data, 'test': test_data}
        equal_variances = test_variance_homogeneity(combined_data)
        
        if equal_variances:
            # Hedges' g, then the small-sample correction J = 1 - 3 / (4 * (n1 + n2) - 9).
            hedge = mean_difference / pooled_standard_deviation
            corrected_hedge = hedge * (1 - 3 / (4 * sample_size - 9))

            print("Using Hedge's g*:")
            _report_effect_size(corrected_hedge)
            return corrected_hedge

        # Glass's delta standardises by the control group's standard deviation only.
        glass_delta = mean_difference / np.sqrt(control_variance)

        print("Using Glass's delta:")
        _report_effect_size(glass_delta)
        return glass_delta

    if method == 'median_diff':
        
        control_median = np.median(control_data)
        test_median = np.median(test_data)

        # Median difference expressed as a fold change relative to the control median.
        median_diff = test_median - control_median
        fold_change = median_diff / control_median

        _report_effect_size(fold_change)
        return fold_change

    raise ValueError("Unknown method: " + repr(method) + " (expected 'hedge_g' or 'median_diff').")


def get_intersection(slopes, intercepts, tolerance=1e-5):
    '''
    Computes the intersection between two lines given in slope-intercept form (y = m * x + b).
    Takes the following arguments:
    
    slopes:     (list); list containing the slopes of the two lines to compare.
    intercepts  (list); list containing the y-intercepts of the two lines to compare.
    tolerance   (float); maximum absolute difference allowed between the y values computed from
                         each line at the intersection x, to absorb floating-point error.
    
    Returns a tuple (x, y) for the intersection point of the two lines, with y rounded to 5 decimals.
    Returns None (after printing the reason) when a slope is infinite or NaN (a vertical line cannot be
    expressed in slope-intercept form), when the lines are parallel or coincident, or when the two y
    values disagree by more than the tolerance.
    '''
    
    if np.isinf(slopes[0]) or np.isinf(slopes[1]):
        print('One of the slopes is infinite!')
        return None
    if np.isnan(slopes[0]) or np.isnan(slopes[1]):
        print('One of the slopes is undefined (NaN)!')
        return None
    if slopes[0] == slopes[1]:
        if intercepts[0] == intercepts[1]:
            print('Lines are parallel and colinear.')
        else:
            print('Lines are parellel, no intersection possible.')
        return None
    
    intersect_x = (intercepts[1] - intercepts[0]) / (slopes[0] - slopes[1])
    y_from_line0 = slopes[0] * intersect_x + intercepts[0]
    y_from_line1 = slopes[1] * intersect_x + intercepts[1]
    
    # Compare with a tolerance: exact equality after rounding fails when the two values
    # straddle a rounding boundary. The negated form also rejects NaN values.
    if not abs(y_from_line0 - y_from_line1) <= tolerance:
        print('Y intersects do not match:\ny1 =', round(y_from_line0, 5), '\ny2 =', round(y_from_line1, 5))
        return None

    return (intersect_x, round(y_from_line0, 5))


def get_parameters(frame, behavior_dict, PPM, inner_color='gold'):
    '''
    Computes descriptive parameters for both flies' ellipses in one frame, and the geometry relating the
    aggressive fly's major axis to the mating fly's body.
    Takes the following arguments:
    
    frame:          (int); the frame (row label) from which we want to extract fly parameters.
    behavior_dict:  (dict); maps 'aggressive' and 'mating' to a pandas DataFrame of tracking data for that female.
                            Each DataFrame must provide 'pos_x', 'pos_y', 'major_axis_len_mm',
                            'minor_axis_len_mm' and 'ori_deg' columns.
    PPM:            (float); pixel-per-millimeter conversion factor.
    inner_color:    (str); color to use for the inside of the ellipses.
    
    Returns a dictionary with one sub-dictionary of ellipse parameters per female ('aggressive', 'mating'),
    plus 'intersect_coords', 'ellipse_point', 'body_intersection_point', 'absolute_theta', 'mating_arc'
    and 'dist_to_aggressor'. Returns None when the geometry is degenerate (the required lines do not
    intersect, or the computed angle/distances are NaN).
    '''
    
    # Define starting dictionary to store our parameters in.
    params = {'aggressive': {}, 'mating': {}}
    
    for behavior, df in behavior_dict.items():
    
        # Getting tracking parameters for given frame; positions are divided by PPM to convert them to mm.
        x_centroid = df.loc[frame, 'pos_x'] / PPM
        # Mirror y about pixel row 380, presumably to turn image rows (y down) into y-up coordinates.
        # NOTE(review): 380 looks like the video frame height in pixels — confirm.
        y_centroid = abs((df.loc[frame, 'pos_y'] - 380) / PPM)
        width = df.loc[frame, 'major_axis_len_mm']
        height = df.loc[frame, 'minor_axis_len_mm']
        angle = df.loc[frame, 'ori_deg'] + 180

        # Get ellipse.
        ellipse = mpl.patches.Ellipse((x_centroid, y_centroid), width, height, angle=angle, facecolor=inner_color)

        # Half the major axis, projected on y (opposite) and x (adjacent): offsets of head and tail from the centroid.
        opposite = np.sin(np.deg2rad(-angle)) * (width / 2)
        adjacent = np.cos(np.deg2rad(-angle)) * (width / 2)

        # Get head and tail coordinates.
        x_head = x_centroid - adjacent
        x_tail = x_centroid + adjacent
        y_head = y_centroid + opposite
        y_tail = y_centroid - opposite

        ### GET MAJOR AXIS LINE EQUATION.

        # Slope of the major axis; NaN marks a vertical (or undefined) line.
        major_x_diff = x_tail - x_head
        major_y_diff = y_tail - y_head
        if major_x_diff == 0 or np.isnan(major_x_diff):
            major_slope = np.nan
        else:
            major_slope = major_y_diff / major_x_diff

        # Calculate intercept of major axis.
        major_intercept = y_head - major_slope * x_head

        # Get line that goes through and extends beyond the fly's major axis (x from 0 to 40).
        major_x = np.array([0, 40])
        major_y = major_slope * major_x + major_intercept

        ### GET FLANKS

        # Half the minor axis (perpendicular to the major axis), projected on y and x.
        minor_opposite = np.sin(np.deg2rad(-angle-90)) * (height / 2)
        minor_adjacent = np.cos(np.deg2rad(-angle-90)) * (height / 2)

        # Get flank coordinates.
        x_left_flank = x_centroid - minor_adjacent
        x_right_flank = x_centroid + minor_adjacent
        y_left_flank = y_centroid + minor_opposite
        y_right_flank = y_centroid - minor_opposite

        ### GET MINOR AXIS LINE EQUATION.

        # Slope of the minor axis; NaN marks a vertical (or undefined) line.
        minor_x_diff = x_left_flank - x_right_flank
        minor_y_diff = y_left_flank - y_right_flank
        if minor_x_diff == 0 or np.isnan(minor_x_diff):
            minor_slope = np.nan
        else:
            minor_slope = minor_y_diff / minor_x_diff

        # Calculate intercept of minor axis.
        minor_intercept = y_left_flank - minor_slope * x_left_flank

        # Get line that goes through and extends beyond the fly's minor axis (x from 0 to 40).
        minor_x = np.array([0, 40])
        minor_y = minor_slope * minor_x + minor_intercept

        # Save properties to dictionary.
        params[behavior]['centroid'] = (x_centroid, y_centroid)
        params[behavior]['width'] = width
        params[behavior]['height'] = height
        params[behavior]['angle'] = angle
        params[behavior]['ellipse'] = ellipse
        params[behavior]['head'] = (x_head, y_head)
        params[behavior]['tail'] = (x_tail, y_tail)
        params[behavior]['major_slope'] = major_slope
        params[behavior]['major_intercept'] = major_intercept
        params[behavior]['minor_slope'] = minor_slope
        params[behavior]['minor_intercept'] = minor_intercept
        params[behavior]['major_x'] = major_x
        params[behavior]['major_y'] = major_y
        params[behavior]['minor_x'] = minor_x
        params[behavior]['minor_y'] = minor_y
        params[behavior]['left_flank'] = (x_left_flank, y_left_flank)
        params[behavior]['right_flank'] = (x_right_flank, y_right_flank)

    # Calculate intersection point of the two flies' major axis lines.
    slopes = [params['aggressive']['major_slope'], params['mating']['major_slope']]
    intercepts = [params['aggressive']['major_intercept'], params['mating']['major_intercept']]
    intersect_coords = get_intersection(slopes, intercepts)
    if intersect_coords is None:
        return None
    
    # Get intersection point of the aggressive female's major body axis with the mated female's ellipse.
    ellipse_point = get_body_point(params['aggressive']['head'], intersect_coords, params['mating'])

    # Get intersection point of the mating female's major body axis with the line that is parallel to its minor axis and passes through ellipse body point.
    ellipse_point_axis_intercept = ellipse_point[1] - params['mating']['minor_slope'] * ellipse_point[0]
    new_slopes = [params['mating']['major_slope'], params['mating']['minor_slope']]
    new_intercepts = [params['mating']['major_intercept'], ellipse_point_axis_intercept]
    body_intersection_point = get_intersection(new_slopes, new_intercepts)
    if body_intersection_point is None:
        return None

    # Angle at the mating centroid between its major axis and the ellipse body point: arccos(adjacent / hypotenuse).
    center_to_new_intersection = round(np.sqrt((body_intersection_point[0] - params['mating']['centroid'][0])**2 + (body_intersection_point[1] - params['mating']['centroid'][1])**2), 5)
    ellipse_point_to_center = round(np.sqrt((ellipse_point[0] - params['mating']['centroid'][0])**2 + (ellipse_point[1] - params['mating']['centroid'][1])**2), 5)
    # Rounding can push the ratio slightly above 1, outside arccos's domain, so clip it; np.divide avoids
    # a ZeroDivisionError if the hypotenuse rounds to a plain-float zero.
    cos_ratio = np.clip(np.divide(center_to_new_intersection, ellipse_point_to_center), -1.0, 1.0)
    theta_center_body_point = np.rad2deg(np.arccos(cos_ratio))

    # Determine on which half (head/tail) and side (left/right flank) of the mating fly the body point lies.
    distance_to_head = np.sqrt((body_intersection_point[0] - params['mating']['head'][0])**2 + (body_intersection_point[1] - params['mating']['head'][1])**2)
    distance_to_tail = np.sqrt((body_intersection_point[0] - params['mating']['tail'][0])**2 + (body_intersection_point[1] - params['mating']['tail'][1])**2)
    distance_to_left_flank = np.sqrt((ellipse_point[0] - params['mating']['left_flank'][0])**2 + (ellipse_point[1] - params['mating']['left_flank'][1])**2)
    distance_to_right_flank = np.sqrt((ellipse_point[0] - params['mating']['right_flank'][0])**2 + (ellipse_point[1] - params['mating']['right_flank'][1])**2)

    # NaN comparisons are all False, which would leave the arc angles unassigned below; treat as degenerate.
    if np.isnan([theta_center_body_point, distance_to_head, distance_to_tail,
                 distance_to_left_flank, distance_to_right_flank]).any():
        return None

    # The four cases below are exhaustive once NaNs are excluded.
    if distance_to_tail <= distance_to_head and distance_to_right_flank <= distance_to_left_flank:
        arc_theta1 = theta_center_body_point
        arc_theta2 = 180
        absolute_theta = 180 - theta_center_body_point
    elif distance_to_tail > distance_to_head and distance_to_right_flank <= distance_to_left_flank:
        arc_theta1 = 180 - theta_center_body_point
        arc_theta2 = 180
        absolute_theta = theta_center_body_point
    elif distance_to_tail <= distance_to_head and distance_to_right_flank > distance_to_left_flank:
        arc_theta1 = 180
        arc_theta2 = -theta_center_body_point
        absolute_theta = -180 + theta_center_body_point
    else:
        arc_theta1 = 180
        arc_theta2 = 180 + theta_center_body_point
        absolute_theta = -theta_center_body_point

    # Arc (radius 3/4 of the centroid-to-projection distance) used to visualise the computed angle.
    mating_arc = mpl.patches.Arc(params['mating']['centroid'],
                                 width=2*center_to_new_intersection*(3/4),
                                 height=2*center_to_new_intersection*(3/4),
                                 angle=params['mating']['angle'],
                                 theta1=arc_theta1,
                                 theta2=arc_theta2,
                                 linewidth=2,
                                 color='gold'
                                )
    
    # Calculate distance from head of aggressive female to mating fly body intersection point.
    distance_to_body_point = round(np.sqrt((params['aggressive']['head'][0] - body_intersection_point[0])**2 + (params['aggressive']['head'][1] - body_intersection_point[1])**2), 5)
    
    # Store additional parameters in our initial dictionary.
    params['intersect_coords'] = intersect_coords
    params['ellipse_point'] = ellipse_point
    params['body_intersection_point'] = body_intersection_point
    params['absolute_theta'] = absolute_theta
    params['mating_arc'] = mating_arc
    params['dist_to_aggressor'] = distance_to_body_point

    return params


def is_in_ellipse(point, ellipse_parameters):
    '''
    Checks whether a given point lies strictly inside a given ellipse.
    Takes the following arguments:
    
    point:               (tuple); (x, y) coordinates of the point to test.
    ellipse_parameters:  (dict); ellipse description with the keys 'centroid' ((x, y) centre), 'width' and
                                 'height' (full axis lengths) and 'angle' (rotation, in degrees).
    
    Returns a bool: True if the point is inside the ellipse, False if it lies on or outside its boundary.
    Source: https://stackoverflow.com/questions/37031356/check-if-points-are-inside-ellipse-faster-than-contains-point-method
    '''

    # Rotate the point into the ellipse's own frame (centre at the origin, axes aligned with width/height).
    # NOTE(review): the `180 - angle` term presumably matches the angle convention of the tracking data — confirm.
    cos_angle = np.cos(np.radians(180 - ellipse_parameters['angle']))
    sin_angle = np.sin(np.radians(180 - ellipse_parameters['angle']))

    x_diff = point[0] - ellipse_parameters['centroid'][0]
    y_diff = point[1] - ellipse_parameters['centroid'][1]

    x_transform = x_diff * cos_angle - y_diff * sin_angle
    y_transform = x_diff * sin_angle + y_diff * cos_angle

    # Normalised squared distance: < 1 inside, == 1 on the boundary, > 1 outside.
    radius = (x_transform ** 2 / (ellipse_parameters['width'] / 2) ** 2) + (y_transform ** 2 / (ellipse_parameters['height'] / 2) ** 2)

    # bool() turns the numpy boolean into a plain Python bool.
    return bool(radius < 1)


def is_intersecting(ellipse_parameters):
    '''
    Tells whether the crossing point of the two females' lines falls inside the mating fly's ellipse.
    Takes the following arguments:
    
    ellipse_parameters  (dict); fly parameters holding the crossing point under 'intersect_coords' and the
                        mating fly's ellipse description under 'mating'.
    
    Returns True when the crossing point is inside that ellipse, False otherwise.
    '''

    crossing_point = ellipse_parameters['intersect_coords']
    mating_ellipse = ellipse_parameters['mating']
    return bool(is_in_ellipse(crossing_point, mating_ellipse))


def load_JAABA_behavior(cop_annotation, condition, behavior, file_path):
    '''
    Reads JAABA scores for a behavior and returns the frame(s) at which that behavior first starts.
    Which scores file and fly target are read depends on the condition and on the comment of the
    first copulation event.
    Takes the following arguments:
    
    cop_annotation    (Annotation object) annotation data; cop_annotation[0].events holds the copulation events.
    condition         (str) experimental condition. 'male_competition', 'virgin_virgin', 'single_pair' and
                      'virgin_added' have dedicated branches; any other value is resolved from the second
                      '_'-separated field of the first event's comment.
    behavior          (str) base name of the JAABA behavior; a pair suffix such as '_1_2' may be appended to it.
    file_path         (str) FlyTracker output folder passed to read_behavior_predictions
                      ('_slice1' is appended to it for 'virgin_added').

    Returns:
      - with one copulation event: the first detected onset frame of the behavior;
      - with two copulation events: a list [first onset in the scores chosen for the first event,
        first onset in the scores chosen for the second event that comes after the first copulation ends].
        For 'virgin_added' the second value is the first copulation's end frame instead;
      - with any other number of events: None (no branch matches and the function falls through).
    '''
    
    events = cop_annotation[0].events
    n_events = len(events)
    # NOTE: only 1 or 2 events are handled; any other count falls through and returns None.
    
    if n_events==1:
        # Single copulation: read one scores series, selected by condition (and possibly the comment).
        if condition=='male_competition':
            # The first '_'-separated field of the comment selects target/pair suffix.
            if events[0].comment.split('_')[0]=='unpainted':
                behavior_cop = read_behavior_predictions(file_path, 0, behavior+'_1_2')
            else:
                behavior_cop = read_behavior_predictions(file_path, 1, behavior+'_2_3')

        elif condition=='virgin_virgin':
            if events[0].comment.split('_')[0]=='unpainted':
                behavior_cop = read_behavior_predictions(file_path, 0, behavior+'_1_2')
            else:
                behavior_cop = read_behavior_predictions(file_path, 0, behavior+'_1_3')

        elif condition=='single_pair':
            # No pair suffix and the comment is not consulted.
            behavior_cop = read_behavior_predictions(file_path, 0, behavior)
            
        elif condition=='virgin_added':
            behavior_cop = read_behavior_predictions(file_path+'_slice1', 0, behavior)

        else:
            # Generic conditions: the second '_'-separated field of the comment selects the pair suffix.
            if events[0].comment.split('_')[1]=='virgin':
                behavior_cop = read_behavior_predictions(file_path, 0, behavior+'_1_2')
            else:
                behavior_cop = read_behavior_predictions(file_path, 0, behavior+'_1_3')

        # detect_events is defined elsewhere; assumed to return (onset frames, offset frames).
        # Indexing [0] presumably raises IndexError when no bout is detected — TODO confirm.
        behavior_on, _ = detect_events(behavior_cop)
        first_event = behavior_on[0]

        return first_event

    elif n_events==2: 
        # Two copulations: read one scores series per event; the first event's comment decides the order.
        if condition=='male_competition':
            if events[0].comment.split('_')[0]=='unpainted':
                behavior_cop1 = read_behavior_predictions(file_path, 0, behavior+'_1_2')
                behavior_cop2 = read_behavior_predictions(file_path, 1, behavior+'_2_3')
            else:
                behavior_cop1 = read_behavior_predictions(file_path, 1, behavior+'_2_3')
                behavior_cop2 = read_behavior_predictions(file_path, 0, behavior+'_1_2')

        elif condition=='virgin_virgin':
            if events[0].comment.split('_')[0]=='unpainted':
                behavior_cop1 = read_behavior_predictions(file_path, 0, behavior+'_1_2')
                behavior_cop2 = read_behavior_predictions(file_path, 0, behavior+'_1_3')
            else:
                behavior_cop1 = read_behavior_predictions(file_path, 0, behavior+'_1_3')
                behavior_cop2 = read_behavior_predictions(file_path, 0, behavior+'_1_2')

        elif condition=='virgin_added':
            # Only the '_slice1' scores are read; an all-False series of the same length is a placeholder
            # for the second series (its onsets are never used, see second_event below).
            behavior_cop1 = read_behavior_predictions(file_path+'_slice1', 0, behavior)
            behavior_cop2 = pd.Series([False] * len(behavior_cop1))

        else:
            # NOTE(review): 'single_pair' has no dedicated branch here and ends up in this generic branch
            # (reading '_1_2'/'_1_3' scores), unlike the single-event path above — confirm this is intended.
            if events[0].comment.split('_')[1]=='virgin':
                behavior_cop1 = read_behavior_predictions(file_path, 0, behavior+'_1_2')
                behavior_cop2 = read_behavior_predictions(file_path, 0, behavior+'_1_3')
            else:
                behavior_cop1 = read_behavior_predictions(file_path, 0, behavior+'_1_3')
                behavior_cop2 = read_behavior_predictions(file_path, 0, behavior+'_1_2')

        behavior_on1, _ = detect_events(behavior_cop1)
        behavior_on2, _ = detect_events(behavior_cop2)

        # End frame of the first copulation event.
        first_cop_end = cop_annotation[0].events[0].time_interval[1]
        first_event = behavior_on1[0]

        # Keep only onsets after the first copulation ends (assumes behavior_on2 supports boolean masking,
        # e.g. a numpy array — TODO confirm).
        after_first_cop = behavior_on2[behavior_on2>first_cop_end]
        # The condition is checked first, so the (empty) placeholder onsets are never indexed for 'virgin_added'.
        second_event = first_cop_end if condition=='virgin_added' else after_first_cop[0]

        return [first_event, second_event]


def make_plot(positions_df, df1=None, df2=None, figure_size=(4,4), xlim=390, ylim=380, circle_margin=15, circle_color='black',
              df1_color='g', df2_color='firebrick', plot_style='.', hist2d=False, bin_size=10, colormap1='Greens', colormap2='Reds', min_count=1):
    '''
    Draws xy positions from one or two DataFrames inside a circle that outlines the arena.
    Takes the following arguments:
    
    positions_df:   (pandas DataFrame); positions ('pos_x', 'pos_y' columns) used only to place and size the arena circle.
    df1:            (pandas DataFrame); first, optional set of positions to plot ('pos_x', 'pos_y' columns).
    df2:            (pandas DataFrame); second, optional set of positions to plot ('pos_x', 'pos_y' columns).
    figure_size:    (tuple); (width, height) of the figure.
    xlim:           (int); max x limit of the plot. X ticks are placed every 100 units, plus one at xlim.
    ylim:           (int); max y limit of the plot. Y ticks are placed every 50 units, plus one at ylim.
    circle_margin:  (int); extra radius added to the arena circle beyond the extent of positions_df.
    circle_color:   (str); edge color of the arena circle.
    df1_color:      (str); marker color for df1 (point plot only).
    df2_color:      (str); marker color for df2 (point plot only).
    plot_style:     (str); marker style for the point plot.
    hist2d:         (bool); if True, draw 2d-histograms ("heatmaps") instead of point plots.
    bin_size:       (int); approximate bin size for the 2d-histograms: int(xlim/bin_size) and int(ylim/bin_size)
                           bin edges are spread evenly over [0, xlim+1] and [0, ylim+1], respectively.
    colormap1:      (str); matplotlib colormap for df1's 2d-histogram.
    colormap2:      (str); matplotlib colormap for df2's 2d-histogram.
    min_count:      (int); bins with fewer counts than this are left blank (passed to hist2d as cmin).
    
    Returns the matplotlib Axes the plot is drawn on.
    '''

    _, ax = plt.subplots(figsize=figure_size)

    # Ticks are derived from the limits rather than hard-coded: set_xticks/set_yticks expand the view limits so
    # every tick is visible, so fixed ticks at 390/380 would override smaller xlim/ylim values.
    # With the default limits this yields exactly [0, 100, 200, 300, 390] and [0, 50, ..., 350, 380].
    ax.set_xlim(0, xlim)
    ax.set_xticks(np.append(np.arange(0, xlim, 100), xlim))
    ax.set_ylim(0, ylim)
    ax.set_yticks(np.append(np.arange(0, ylim, 50), ylim))

    # Centre and radius of the circle representing the arena boundaries (the larger half-extent is used for both axes).
    min_x, max_x = positions_df['pos_x'].min(), positions_df['pos_x'].max()
    min_y, max_y = positions_df['pos_y'].min(), positions_df['pos_y'].max()
    radius = max(int((max_x - min_x) / 2), int((max_y - min_y) / 2))
    centre = (min_x + radius, min_y + radius)

    arena = plt.Circle(centre, radius + circle_margin, edgecolor=circle_color, fill=False)
    ax.add_artist(arena)

    # Bin edges for the 2d-histograms.
    plot_bins = [np.linspace(0, xlim+1, int(xlim/bin_size)), np.linspace(0, ylim+1, int(ylim/bin_size))]

    # pyplot-level calls are kept on purpose: they draw on `ax` (the current axes) and plt.hist2d also registers the
    # image as current, so callers can follow up with plt.colorbar().
    if hist2d:
        for df, cmap in ((df1, colormap1), (df2, colormap2)):
            if df is not None:
                plt.hist2d(df['pos_x'], df['pos_y'], bins=plot_bins, cmap=cmap, cmin=min_count)
    else:
        for df, color in ((df1, df1_color), (df2, df2_color)):
            if df is not None:
                plt.plot(df['pos_x'], df['pos_y'], linestyle='', marker=plot_style, color=color)

    return ax


def plot_behaviors(series, color='steelblue', facealpha=1, offset=0, patch_height=1):
    '''
    Generates one matplotlib Rectangle patch per behavior bout detected by detect_events(), for raster plots.
    Takes the following arguments:
    
    series:        (pandas.Series); per-frame behavior series passed to detect_events().
    color:         (str); facecolor for all the Rectangle objects.
    facealpha:     (float); transparency for all Rectangle objects, from 0 (invisible) to 1 (opaque).
    offset:        (int, float); y coordinate at which every Rectangle starts. Default is 0.
    patch_height:  (int, float); height of every Rectangle. Default is 1.

    Returns a tuple (rectangles, widths): a list of matplotlib Rectangle objects, one per bout, and a list with
    each bout's width (end frame minus start frame).
    '''

    # detect_events is defined elsewhere; assumed to return parallel (start frames, end frames) sequences.
    behavior_on, behavior_off = detect_events(series)

    # Each bout's duration becomes the width of its rectangle.
    widths = [off - on for on, off in zip(behavior_on, behavior_off)]

    # Rectangles are anchored at (start frame, offset); zip stops at the shorter of the two sequences.
    behavior_rects = [
        mpatches.Rectangle((start, offset), width, patch_height, facecolor=color, alpha=facealpha, linewidth=0)
        for start, width in zip(behavior_on, widths)
    ]

    return behavior_rects, widths


def plot_dot(series, dot_width=150, dot_height=0.3, dot_color='red', dot_edgecolor='black', facealpha=1, offset=0, size=10):
    '''
    Same idea as plot_behaviors(), but marks each detected bout with a matplotlib Ellipse patch
    centred on the bout's start frame.
    Takes the following arguments:
    
    series:         (pandas.Series); per-frame behavior series passed to detect_events().
    dot_width:      (int, float); width of each Ellipse.
    dot_height:     (float); height of each Ellipse.
    dot_color:      (str); facecolor for all the Ellipse objects.
    dot_edgecolor:  (str); edge color for all the Ellipse objects.
    facealpha:      (float); transparency for all Ellipse objects, from 0 (invisible) to 1 (opaque).
    offset:         (int, float); y coordinate of every Ellipse's centre. Default is 0.
    size:           (int); currently unused; kept for backward compatibility.

    Returns a list of matplotlib Ellipse objects for drawing the corresponding behavior on a raster plot.
    '''

    # detect_events is defined elsewhere; only the start frames are used here.
    behavior_on, _ = detect_events(series)

    behavior_circles = [
        mpatches.Ellipse((start, offset), width=dot_width, height=dot_height, facecolor=dot_color,
                         edgecolor=dot_edgecolor, alpha=facealpha)
        for start in behavior_on
    ]

    return behavior_circles


def plot_polar_histogram(angles, angles_bin_size=10, mode='absolute', theta_zero_loc='S', rlabels=True, rlabels_pos=195, radial_limit=None,
                         radial_bins=None, figure_size=(6, 10), ink='black'):
    """
    Produce a circular histogram of angles on a new polar matplotlib axis.
    Takes the following arguments:

    angles:            (numpy array or array-like); angular data in radians, all within [-pi, pi].
    angles_bin_size:   (int); bin size for the angle data. Given in degrees.
    mode:              (str); the kind of plot to create. Must be one of the following:
                              "absolute" - use absolute counts. Frequency matches bar length.
                              "normalized" - use fraction of counts. Frequency matches bar length.
                              "density" - use fraction of counts. Frequency matches bar area.
    theta_zero_loc:    (str); location to place zero degrees at.
                              Default is 'S' for 'South', which places it at the bottom of the figure.
    rlabels:           (bool); whether or not to draw the radial labels.
    rlabels_pos:       (int); the position, in degree angles, at which to draw the radial labels.
    radial_limit:      (float); the y limit for the radial axis.
    radial_bins:       (int); the number of bins to use from zero to the radial limit.
    figure_size:       (tuple) size (width, height) of the drawn figure.
    ink:               (str); color to use for figure elements ('black' ink gives a white background, anything else a black one).

    Returns the matplotlib Axes on which the polar histogram is drawn.
    Raises ValueError if any angle lies outside [-pi, pi] or if mode is not recognised.
    """

    angles = np.asarray(angles)

    # Validate element-wise: comparing the bounds against `angles.all()` would only test a single boolean.
    if np.any((angles < -np.pi) | (angles > np.pi)):
        raise ValueError('All angles must be within -pi and pi radians.')

    # Convert bin size to radians.
    angles_bin_size_rad = np.deg2rad(angles_bin_size)

    # Force bins to partition the entire circle (the +0.01 lets the last edge reach pi despite float steps).
    angles_bins = np.array(np.arange(-np.pi, np.pi+0.01, angles_bin_size_rad))

    # Bin data and record counts.
    counts, bins = np.histogram(angles, bins=angles_bins)

    # Compute width of each bin.
    widths = np.diff(bins)

    # Calculate the radius (i.e., length) of the bars based on the bin counts and selected mode.
    if mode == 'density':
        # Area assigned to each bin, and the bar radius giving that area.
        area = counts / angles.size
        radius = np.sqrt(area/np.pi)
    elif mode == 'normalized':
        # Fraction of all samples in each bin.
        radius = counts / angles.size
    elif mode == 'absolute':
        # Raw counts, proportional to bar length.
        radius = counts
    else:
        raise ValueError('keyword argument "mode" must be one of "absolute", "density" or "normalized".')

    # Initiate a polar plot on a new figure.
    plt.figure(figsize=figure_size, facecolor='white' if ink=='black' else 'black')
    ax = plt.subplot(projection='polar')

    # Plot data on axis.
    ax.bar(bins[:-1], radius, zorder=1, align='edge', width=widths, edgecolor=ink, facecolor='steelblue', linewidth=1)

    # Format angular grid labels (ticks every 30 degrees, labelled with negated angles in [-180, 180]).
    theta_ticks = np.linspace(start=0, stop=360, num=12, endpoint=False)
    theta_labels = [str(int(deg_ang*-1)) if deg_ang < 180 else str(int((deg_ang-360)*-1)) for deg_ang in theta_ticks]
    theta_labels_fmt = ['{}\N{DEGREE SIGN}'.format(label) for label in theta_labels]
    ax.set_thetagrids(angles=theta_ticks, labels=theta_labels_fmt, fontsize=16, color=ink)
    ax.set_facecolor('white' if ink=='black' else 'black')
    ax.spines['polar'].set_color(ink)
    ax.set_theta_zero_location(loc=theta_zero_loc)
    ax.tick_params(axis='x', pad=18)

    # Format radial grid lines.
    if rlabels:
        if radial_limit is None:
            ax.set_rlim(0, max(ax.get_yticks()))
        else:
            ax.set_rlim(0, radial_limit)
            ax.set_yticks([y for y in ax.get_yticks()] if radial_bins is None else np.linspace(start=0, stop=radial_limit, num=radial_bins+1, endpoint=True))
        ax.grid(True)
        ax.set_rlabel_position(value=rlabels_pos)
        ax.tick_params(axis='y', labelsize=12)
    else:
        ax.set_yticks([])
        ax.set_yticklabels([])

    return ax


def plot_polar_relative_heatmap(angles, radius, angles_bin_size=10, radius_bin_limit=10, radius_bin_size=1, theta_zero_loc='S', rlabels=True, show_metric=True,
                                metric='mm', rlabels_pos=195, cbar=True, normalization='min_max', ax=None, colormap='RdYlBu_r', figure_size=(6, 10), ink='black'):
    '''
    Creates a polar heatmap (2D histogram of angle vs. radial value).
    Takes the following arguments:
    
    angles:            (numpy array or array-like); angular data in radians, all within [-pi, pi].
    radius:            (numpy array or array-like); data to use for the radial dimension. Can be distance, angles, etc.
    angles_bin_size:   (int); bin size for the angle data. Given in degrees.
    radius_bin_limit:  (int); final edge for the last radial bin.
    radius_bin_size:   (int); bin size for the data in the radial dimension.
    theta_zero_loc:    (str); location to place zero degrees at.
                              Default is 'S' for 'South', which places it at the bottom of the figure.
    rlabels:           (bool); whether or not to draw the radial labels.
    show_metric:       (bool); whether or not to append `metric` to the radial labels.
    metric:            (str); unit string appended to the radial labels.
    rlabels_pos:       (int); the position, in degree angles, at which to draw the radial labels.
    cbar:              (bool); whether or not to draw a colorbar.
    normalization:     (str); normalization strategy. 'min_max' divides counts by the maximum count (so the largest bin
                              becomes 1); 'proba' divides counts by the total count. Any other input keeps absolute counts.
                              Empty bins are masked and left blank in every case.
    ax:                (matplotlib.Axes); polar axis to draw on. If None, a new figure with a polar subplot is created.
    colormap:          (str); color palette to use in the heatmap and colorbar.
    figure_size:       (tuple); size (width, height) of the new figure (only used when ax is None).
    ink:               (str); color to use for figure elements ('black' ink gives a white background, anything else a black one).

    Returns the matplotlib Axes with the heatmap when cbar is True; otherwise a tuple (ax, mesh), where mesh is the
    object returned by pcolormesh.
    Raises ValueError if any angle lies outside [-pi, pi].
    '''

    angles = np.asarray(angles)
    radius = np.asarray(radius)

    # Validate element-wise: comparing the bounds against `angles.all()` would only test a single boolean.
    if np.any((angles < -np.pi) | (angles > np.pi)):
        raise ValueError('All angles must be within -pi and pi radians.')

    # Convert bin size to radians.
    angles_bin_size_rad = np.deg2rad(angles_bin_size)

    # Calculate bins (the +0.01 lets the last angular edge reach pi despite float steps).
    angles_bins = np.array(np.arange(-np.pi, np.pi+0.01, angles_bin_size_rad))
    radius_bins = np.array(range(0, radius_bin_limit+1, radius_bin_size))

    # Polar mesh grid of bin edges.
    theta_grid, radius_grid = np.meshgrid(angles_bins, radius_bins, indexing='ij')

    # 2D counts per (angle, radius) bin.
    H, _, _ = np.histogram2d(angles, radius, bins=(angles_bins, radius_bins))

    # Mask empty bins so pcolormesh leaves them blank.
    H = np.ma.masked_where(H==0, H)

    # Normalize counts.
    if normalization=='min_max':
        H = H / np.max(H)
        mesh_vmax = 1
        cbar_ticks = (0, 1)
        cbar_label = 'Normalized Counts'
        cbar_padding = -10
    elif normalization=='proba':
        H = H / np.sum(H)
        mesh_vmax = 1
        cbar_ticks = (0, 1)
        cbar_label = 'Normalized Counts'
        cbar_padding = -10
    else:
        mesh_vmax = None
        # H is a 2D masked array: use its own max(); the built-in max() would iterate rows and fail.
        cbar_ticks = np.linspace(0, H.max(), num=4, endpoint=True)
        cbar_label = 'Absolute Counts'
        cbar_padding = 6

    # Create the polar plot; a new figure is only needed when no axis was supplied.
    if ax is None:
        plt.figure(figsize=figure_size, facecolor='white' if ink=='black' else 'black')
        ax = plt.subplot(projection='polar')

    pax = ax.pcolormesh(theta_grid, radius_grid, H, cmap=colormap, vmin=0, vmax=mesh_vmax, zorder=0, antialiased=True)

    # Add a color bar to the figure that owns `ax` (not necessarily the current figure when `ax` is supplied).
    if cbar:
        colorbar = ax.figure.colorbar(pax, ax=ax, ticks=cbar_ticks, pad=0.10, orientation='horizontal', shrink=0.80)
        colorbar.set_label(cbar_label, labelpad=cbar_padding)
        colorbar.set_ticklabels([str(round(float(tick), 1)) if mesh_vmax is None else tick for tick in colorbar.ax.get_xticks()])
        colorbar.ax.tick_params(labelsize=16)

    # Format angular grid labels (ticks every 30 degrees, labelled with negated angles in [-180, 180]).
    theta_ticks = np.linspace(start=0, stop=360, num=12, endpoint=False)
    theta_labels = [str(int(deg_ang*-1)) if deg_ang < 180 else str(int((deg_ang-360)*-1)) for deg_ang in theta_ticks]
    theta_labels_fmt = ['{}\N{DEGREE SIGN}'.format(label) for label in theta_labels]
    ax.set_thetagrids(angles=theta_ticks, labels=theta_labels_fmt, fontsize=16, color=ink)
    ax.set_facecolor('white' if ink=='black' else 'black')
    ax.spines['polar'].set_color(ink)
    ax.set_theta_zero_location(loc=theta_zero_loc)
    ax.tick_params(axis='x', pad=18)

    # Format radial grid lines.
    ax.grid(True)
    ax.set_rlabel_position(value=rlabels_pos)
    if not rlabels:
        ax.set_yticklabels([])
    else:
        ax.set_yticks(range(0, radius_bin_limit+1, radius_bin_size))
        yticklabels = [str(y) + metric if show_metric==True else str(y) for y in ax.get_yticks()]
        # Hide the label at the centre of the plot.
        yticklabels[0] = ''
        ax.set_yticklabels(yticklabels)
    ax.set_rgrids(radii=[float(n) for n in range(0, radius_bin_limit+1, radius_bin_size)], color=ink)
    ax.tick_params(axis='y', labelsize=13)
    ax.set_rlim(0, radius_bin_limit)

    if cbar:
        return ax
    return ax, pax


def plot_stattest_result(ax, x1, x2, p_value, y, alpha=0.05, ticksize=1, connector_color='black', **kwargs):
    '''
    Draws a significance bracket between two groups and labels it with stars derived from a p-value.
    Takes the following arguments:

    ax:               (matplotlib.Axes); axis on which to draw the bracket and label.
    x1:               (int, float); x coordinate of the bracket's left end.
    x2:               (int, float); x coordinate of the bracket's right end.
    p_value:          (float, str); p-value to convert into stars (e.g., 0.0028 -> '**' with alpha=0.05);
                                    a string is used verbatim as the label.
    y:                (int, float); height of the bracket's horizontal line.
    alpha:            (float); significance level.
    ticksize:         (int, float); length of the downward ticks at both ends of the bracket.
    connector_color:  (str); color of the bracket.
    **kwargs:          Optional keyword arguments forwarded to ax.annotate, overriding the defaults:
                            - xytext: offset of the label (default (0, 5) for 'ns', (0, 2.5) otherwise).
                            - textcoords: coordinate system of xytext (default 'offset points').
                            - annotation_clip: whether to clip the label at the axis edge (default False).
                            - ha: horizontal alignment of the label (default 'center').
                            - va: vertical alignment of the label (default 'center').

    Returns nothing; the bracket and label are added to `ax`.
    '''

    def sigstars(p_value, alpha):
        # Pre-formatted labels pass straight through.
        if isinstance(p_value, str):
            return p_value
        # One more star for each 10x-ish step below alpha: alpha, alpha/5, alpha/50, alpha/500.
        for divisor, label in ((500, '****'), (50, '***'), (5, '**'), (1, '*')):
            if p_value < alpha / divisor:
                return label
        return 'ns'

    stars = sigstars(p_value, alpha)

    # 'ns' sits a little higher so it does not touch the bracket.
    text_offset = 5 if stars == 'ns' else 2.5
    options = dict(xytext=(0, text_offset), textcoords='offset points', annotation_clip=False, ha='center', va='center')
    options.update(kwargs)

    bracket_xs = [x1, x1, x2, x2]
    bracket_ys = [y - ticksize, y, y, y - ticksize]
    connector = plt.Line2D(bracket_xs, bracket_ys, color=connector_color, linewidth=1)
    connector.set_clip_on(False)
    ax.add_line(connector)

    midpoint = (x1 + x2) / 2
    ax.annotate(stars, xy=(midpoint, y), **options)

    
def process_track_data(experiment_path, include_aggression_only=True, trackfeat_columns=None, is_control=False, fly_id=1, suppress=False):
    '''
    Reads specific columns from an experiment's FlyTracker -trackfeat .csv files and returns them as pandas DataFrames.
    Takes the following arguments:
    
    experiment_path:          (str); path to the experiment's folder; the annotation file is read from experiment_path + '.csv'.
    include_aggression_only:  (bool); whether to keep only videos with aggression, or get all videos with copulation, irrespective of aggression.
    trackfeat_columns:        (list); columns to read from FlyTracker's -trackfeat .csv files. Defaults to ['pos x', 'pos y'].
                                      Spaces in these names are replaced by underscores in the returned DataFrames.
    is_control:               (bool); whether the current experiment belongs to the control group. This decides how the fly identities are chosen.
    fly_id:                   (int); which fly's information to fetch (1 or 2); any value other than 1 selects the fly-2 pairing.
                                     Check the appropriate txt file for a full description of fly identities.
    suppress:                 (bool); whether or not to suppress print statements.
    
    Returns a 4-tuple (aggressive_data, mating_data, (first_copulation_start, first_copulation_end), aggression_timepoints),
    where aggression_timepoints is a list of (start, end) frame tuples (empty if there is no aggression and
    include_aggression_only is False). Returns (None, None, None, None) if there is no copulation, or if there is no
    aggression and include_aggression_only is True.
    '''

    # None sentinel instead of a shared mutable list default.
    if trackfeat_columns is None:
        trackfeat_columns = ['pos x', 'pos y']

    experiment = os.path.basename(experiment_path)

    # Read annotations file: track 0 holds copulation events, track 1 aggression events.
    annotation_data = annotations.read(experiment_path + '.csv')

    # Disregard videos without copulation; otherwise get start and end frames for the first copulation.
    copulation_track = annotation_data[0]
    if len(copulation_track.events) == 0:
        if not suppress:
            print('\nNo copulation for experiment ' + experiment + '.\n')
        return (None, None, None, None)
    first_copulation = copulation_track.events[0]
    first_copulation_start = first_copulation.time_interval[0]
    first_copulation_end = first_copulation.time_interval[1]

    # Get start and end frames for each aggression bout (or bail out if aggression is required but absent).
    aggressive_track = annotation_data[1]
    if len(aggressive_track.events) != 0:
        aggression_timepoints = [(event.time_interval[0], event.time_interval[1]) for event in aggressive_track.events]
    elif include_aggression_only:
        if not suppress:
            print('\nNo aggression for experiment ' + experiment + '.\n')
        return (None, None, None, None)
    else:
        aggression_timepoints = []

    if is_control:
        # The first copulation's comment determines the identity of the aggressive female.
        comment = first_copulation.comment
        first_field = comment.split('_')[0]
        if fly_id == 1:
            if first_field == 'painted':
                mating_subject = 'pair_1_2_subject1'
                aggressive_subject = 'pair_1_2_subject2'
            else:
                mating_subject = 'pair_1_3_subject1'
                aggressive_subject = 'pair_1_3_subject3'
        else:
            if first_field == 'painted':
                mating_subject = 'pair_2_3_subject3'
                aggressive_subject = 'pair_2_3_subject2'
            else:
                mating_subject = 'pair_2_3_subject2'
                aggressive_subject = 'pair_2_3_subject3'
    else:
        if fly_id == 1:
            mating_subject = 'pair_1_3_subject1'
            aggressive_subject = 'pair_1_3_subject3'
        else:
            mating_subject = 'pair_2_3_subject2'
            aggressive_subject = 'pair_2_3_subject3'

    # Read the trackfeat files and make the column names attribute-friendly.
    trackfeat_folder = os.path.join(experiment_path, experiment + '-trackfeat')
    aggressive_data = pd.read_csv(os.path.join(trackfeat_folder, aggressive_subject + '.csv'), usecols=trackfeat_columns)
    mating_data = pd.read_csv(os.path.join(trackfeat_folder, mating_subject + '.csv'), usecols=trackfeat_columns)

    renamed_columns = {feature: feature.replace(' ', '_') for feature in trackfeat_columns}
    aggressive_data.rename(columns=renamed_columns, inplace=True)
    mating_data.rename(columns=renamed_columns, inplace=True)

    if not suppress:
        print(experiment)

    return (aggressive_data, mating_data, (first_copulation_start, first_copulation_end), aggression_timepoints)


# Taken from original courtana's "jaaba.py".
def read_behavior_predictions(output_folder, target, behavior):
    """Load the JAABA predictions for one behavior of a FlyTracker-processed movie.

    output_folder: path to the FlyTracker output folder of a video; its scores are
                   looked up in <output_folder>/<folder name>_JAABA/scores_<behavior>.mat
    target:        number of the fly as JAABA labeled it (0,...,N)
    behavior:      behavior name of the jab file

    Returns a boolean pandas.Series indexed by frame, True where the behavior
    is predicted.
    """
    movie_name = os.path.basename(output_folder)
    jaaba_folder = os.path.join(output_folder, movie_name + '_JAABA')
    return read_scores(os.path.join(jaaba_folder, 'scores_' + behavior + '.mat'), target)


def read_scores(filename, target):
    """
    Load the post-processed JAABA predictions of one fly from a scores file.

    filename: a scores_*.mat file containing an 'allScores' struct
    target:   number of the fly as JAABA labeled it (0,...,N)

    Returns a boolean pandas.Series indexed by frame, True where the behavior
    is predicted.
    """
    all_scores = loadmat(filename)['allScores']
    # 'postprocessed' is a cell array with one row of per-fly predictions.
    postprocessed = all_scores['postprocessed'][0, 0]
    fly_predictions = postprocessed[0, target][0]
    return pd.Series(fly_predictions, dtype=bool)


def remove_short_events(events, min_duration=5, suppress=False):
    '''
    Drops events that are not longer than a given number of frames.
    Takes the following arguments:
    
    events:         (list); a list of (start, end) tuples.
    min_duration:   (int); events whose duration (end - start) is less than or equal to this value are removed;
                           only strictly longer events are kept.
    suppress:       (bool); whether or not to hide the per-event "Skipped event" messages.
    
    Returns a new list with the kept (start, end) tuples, in their original order.
    '''

    # Vectorised comparison of every event's duration against the threshold.
    too_short = np.array([event[1] - event[0] for event in events]) <= min_duration

    kept = []
    for event, is_short in zip(events, too_short):
        if is_short:
            if not suppress:
                print('Skipped event:', event, str(event[1] - event[0]))
            continue
        kept.append(event)

    return kept


def remove_outliers(data, indices, suppress=False):
    '''
    Removes outliers from the given dataset at the given index labels.
    Takes the following arguments:

    data:      (pandas.Series, pandas.DataFrame, numpy array, or list); the dataset from which to
               remove outliers. Lists and numpy arrays are wrapped in a pandas Series first, so
               their positions become the index labels.
    indices:   (list); index labels of the outliers. They are passed to `drop(index=...)`, so
               for a Series/DataFrame with a custom index they are labels, not positions.
    suppress:  (bool); whether or not to suppress the informational print statements
               (the >10% warning is always printed).

    Returns a pandas Series (or DataFrame, if a DataFrame was given) without the outliers.
    Raises KeyError if a label in `indices` is not present in the data's index.
    '''

    # Lists and numpy arrays have no `drop` method; wrap them so the documented inputs work.
    if not isinstance(data, (pd.Series, pd.DataFrame)):
        data = pd.Series(data)

    sample_size = len(data)

    fresh_data = data.drop(index=indices)

    # Avoid ZeroDivisionError on empty input (nothing can be dropped from it anyway).
    percentage_drops = len(indices) / sample_size if sample_size else 0.0

    if not suppress:
        print('Number data points considered outliers:', len(indices))
        print('Percent data points considered outliers:', round(percentage_drops * 100, 3), '%')

    if percentage_drops > 0.10:
        print('More than 10% of the dataset has been attributed to outliers, consider checking your data!')

    return fresh_data


def run_statistics(data, alpha=0.05, threshold=0.9, ind_measures=True, contingency=False, value_name='value', var_name='variable'):
    '''
    Statistical suite for comparing groups of measurements.

    For contingency tables, runs Fisher's exact test (two rows) or a chi-square test of
    independence (any other number of rows). Otherwise, every group is tested for normality
    (Shapiro-Wilk and D'Agostino tests), then variance homogeneity is tested with Bartlett's
    test (mostly normal data) or Levene's test (otherwise). Based on those checks the
    comparison test is chosen:

      two groups, independent:        Student's t-test, Welch's t-test, or Mann-Whitney U.
      two groups, repeated measures:  paired t-test or Wilcoxon signed-rank test.
      >2 groups, independent:         one-way ANOVA or Kruskal-Wallis; if significant, a
                                      post-hoc Tukey-HSD test on every pair of groups.
      >2 groups, repeated measures:   repeated-measures ANOVA or Friedman test.

    Takes the following arguments:

    data:                    (dict); a dictionary where the group names or identities are the dictionary
                                     keys, each observation or measurement for that group is stored in a
                                     list, array, or pandas Series as the dictionary values. For
                                     contingency tables, each value is one row of the table. For repeated
                                     measures, the i-th observation of every group is assumed to come
                                     from the same individual.
    alpha:                   (float); Significance level for our statistical tests.
    threshold:               (float); Minimum fraction (inclusive) of normality tests that have to return
                                      positive in order to assume that the data are normally distributed.
    ind_measures:            (bool); whether samples to compare come from independent measurements
                                     (True as default) or if the samples come from repeated measures of
                                     the same individual(s).
    contingency:             (bool); whether or not the input data is a contingency table.
    value_name:              (str); Column name for the value column after melting our data in
                                    preparation for post-hoc / repeated-measures analysis.
    var_name:                (str); Column name for the variable column after melting our data in
                                    preparation for post-hoc / repeated-measures analysis.

    Returns:
      - the p-value (float) for contingency tables, for two groups, and for more than two
        repeated-measures groups;
      - a tuple (p, p_value_dict) for more than two independent groups, where p_value_dict maps
        'groupA-groupB' to the Tukey-HSD p-value of that pair (empty when p >= alpha).
    '''

    assert isinstance(data, dict), 'Input data must be a dictionary.'

    if contingency:
        return _contingency_test_p(data)

    mostly_normal, is_homogeneous = _parametric_assumptions(data, alpha, threshold)
    parametric = is_homogeneous and mostly_normal
    samples = list(data.values())

    if len(samples) == 2:
        if ind_measures:
            if parametric:
                t, p = stats.ttest_ind(*samples, equal_var=True)
                print('\nIndependent measures (two sample) t-test p-value:', p, '\n')
            elif mostly_normal:
                # Normal but heteroscedastic samples: Welch's t-test.
                t, p = stats.ttest_ind(*samples, equal_var=False)
                print("\nIndependent measures (two sample) t-test p-value with Welch's correction for unequal variances:", p, '\n')
            else:
                U, p = stats.mannwhitneyu(*samples)
                print('\nMann-Whitney p-value:', p, '\n')
        else:
            if parametric:
                t, p = stats.ttest_rel(*samples)
                print('\nRepeated measures (paired) t-test p-value:', p, '\n')
            else:
                W, p = stats.wilcoxon(samples[0], samples[1])
                print('\nWilcoxon signed-rank test p-value:', p, '\n')
        return p

    # Fewer than two groups already raised inside the variance tests, so this is > 2 groups.
    if ind_measures:
        if parametric:
            F, p = stats.f_oneway(*samples)
            print('\nANOVA p-value:', p, '\n')
        else:
            H, p = stats.kruskal(*samples)
            print('\nKruskal-Wallis p-value:', p, '\n')

        p_value_dict = _tukey_posthoc_p_values(data, alpha, value_name, var_name) if p < alpha else {}
        return (p, p_value_dict)

    if parametric:
        p = _repeated_measures_anova_p(data, value_name, var_name)
        print('\nRepeated measures ANOVA p-value:', p, '\n')
    else:
        statistic, p = stats.friedmanchisquare(*samples)
        print('\nFriedman chi-square p-value:', p, '\n')
    return p


def _contingency_test_p(data):
    '''
    Tests the contingency table whose rows are the values of `data`: Fisher's exact test for
    two rows, chi-square test of independence otherwise. Returns the p-value.
    '''
    table = np.array(list(data.values()))
    if len(data) == 2:
        oddsratio, p = stats.fisher_exact(table)
        print('\nFischer exact test p-value:', p, '\n')
    else:
        chi2, p, dof, expected = stats.chi2_contingency(table)
        print('\nChi-square for contingency tables p-value:', p, '\n')
    return p


def _parametric_assumptions(data, alpha, threshold):
    '''
    Runs the normality tests on every group and the matching variance-homogeneity test.
    Returns (mostly_normal, is_homogeneous).
    '''
    normtest = []
    for n, values in enumerate(data.values()):
        W, shapiro_p = stats.shapiro(values)
        normally_distributed = shapiro_p > alpha
        report = ("Shapiro's Test: group {} {} normally distributed.")
        print(report.format(str(n+1), "IS" if normally_distributed else "IS NOT"))
        normtest.append(shapiro_p)

        k, dagostino_p = stats.normaltest(values)
        normally_distributed = dagostino_p > alpha
        report = ("D'Agostino's Test: group {} {} normally distributed.")
        print(report.format(str(n+1), "IS" if normally_distributed else "IS NOT"))
        normtest.append(dagostino_p)

    is_normal = np.array(normtest) > alpha
    # Inclusive comparison: `threshold` is the fraction of tests that must pass, so e.g.
    # threshold=1.0 means "all tests pass" instead of being impossible to satisfy.
    mostly_normal = np.count_nonzero(is_normal) / np.size(is_normal) >= threshold

    if mostly_normal:
        T, p = stats.bartlett(*data.values())
        report = ("Bartlett's Test for normally distribted samples:\n"
                  "  p-value = {:5f}\n"
                  "  All groups were sampled from populations with {} variances.")
    else:
        W, p = stats.levene(*data.values())
        report = ("Levene's Test for non-normally distributed samples:\n"
                  "  p-value = {:5f}\n"
                  "  All groups were sampled from populations with {} variances.")
    print(report.format(p, "IDENTICAL" if p > alpha else "NOT IDENTICAL"))

    return mostly_normal, p > alpha


def _tukey_posthoc_p_values(data, alpha, value_name, var_name):
    '''
    Runs a Tukey-HSD test over every pair of groups in `data` and returns a dict mapping
    'groupA-groupB' to the p-value of that pair.
    '''
    # Build the long-form table group by group so groups of different sizes are supported
    # (a wide DataFrame built from unequal-length lists raises ValueError).
    long_df = pd.concat(
        [pd.DataFrame({value_name: np.asarray(values), var_name: group}) for group, values in data.items()],
        ignore_index=True,
    ).dropna()
    print(long_df)

    result = pairwise_tukeyhsd(long_df[value_name], long_df[var_name], alpha=alpha)
    print(result)

    # Assumes pairwise_tukeyhsd orders its comparisons like itertools.combinations over
    # result.groupsunique. str() keeps non-string group keys from breaking the join.
    pairs = ['-'.join(map(str, pair)) for pair in itertools.combinations(result.groupsunique, 2)]
    tukey_p = psturng(np.abs(result.meandiffs / result.std_pairs), len(result.groupsunique), result.df_total)

    p_value_dict = {}
    for pair, tp in zip(pairs, np.atleast_1d(tukey_p)):
        print(pair, 'p-value:', tp)
        p_value_dict[pair] = tp
    return p_value_dict


def _repeated_measures_anova_p(data, value_name, var_name):
    '''
    One-way repeated-measures ANOVA over the groups of `data`. The i-th observation of every
    group is treated as the same subject, so all groups must have the same length.
    Returns the p-value of the within-subject factor.
    '''
    samples = [np.asarray(values, dtype=float) for values in data.values()]
    n_subjects = len(samples[0])
    if any(len(sample) != n_subjects for sample in samples):
        raise ValueError('Repeated-measures ANOVA requires the same number of observations in every group.')

    long_df = pd.DataFrame({
        'subject': np.tile(np.arange(n_subjects), len(samples)),
        var_name: np.repeat([str(group) for group in data], n_subjects),
        value_name: np.concatenate(samples),
    })
    anova = AnovaRM(long_df, depvar=value_name, subject='subject', within=[var_name]).fit()
    return anova.anova_table['Pr > F'].iloc[0]


def series_to_annotation(starts, ends, file_path, behavior, fly_num):
    '''
    Writes behavioral bouts (e.g. the starts/ends generated by detect_events()) to a .csv
    annotation file readable by Python Video Annotator.
    Takes the following arguments:

    starts:        (list) list of integers representing the start of each behavioral bout.
    ends:          (list) list of integers representing the end of each behavioral bout; paired
                          with `starts` by position (extra items in the longer list are ignored).
    file_path:     (str) path prefix for the csv file; the file written is
                         '<file_path>_<behavior>_<fly_num>.csv'.
    behavior:      (str) behavior of interest. 'aggression' is colored #228b22, anything else
                         #6495ed.
    fly_num:       (str or int) identifier for the fly.

    Returns no object, but creates (or overwrites) the .csv file.
    NOTE: every space in the written lines becomes a comma, including spaces inside `behavior`.
    '''

    # format() also accepts non-str identifiers (e.g. an int fly_num or a pathlib.Path prefix).
    filename = '{}_{}_{}.csv'.format(file_path, behavior, fly_num)
    color = '#228b22' if behavior == 'aggression' else '#6495ed'  # Else being courtship.

    # Lines are assembled with spaces and then converted to commas, matching the original format.
    lines = [('T ' + str(behavior) + ' ' + color + '\n').replace(' ', ',')]
    for start, end in zip(starts, ends):
        lines.append(('P False ' + str(start) + ' ' + str(end) + ',,' + color + ' 0\n').replace(' ', ','))

    # Mode 'w' truncates any previous file, so no separate delete step is needed.
    with open(filename, 'w') as annotation_file:
        annotation_file.writelines(lines)



def step_through(df, start=0, end=216000+1, step=60):
    '''
    Takes a DataFrame and generates new DataFrame with only the rows whose index label is one of
    start, start+step, start+2*step, ... (end excluded), i.e. range(start, end, step).
    Takes the following arguments:

    df:     (pandas DataFrame); the DataFrame to sample.
    start:  (int); index label from which to start the sampling.
    end:    (int); index label at which to stop the sampling (exclusive).
    step:   (int); spacing between sampled index labels.

    Returns a modified pandas DataFrame. Rows are ordered by the sampled label in range order;
    rows sharing a label keep their relative order from `df`. Returns an empty DataFrame when
    the range is empty.
    '''

    wanted = range(start, end, step)
    if len(wanted) == 0:
        return pd.DataFrame()

    # One vectorized membership test instead of one full-index scan plus one concat per label.
    selected = df.loc[df.index.isin(wanted)]
    # A stable sort in the direction of the range reproduces the label-by-label ordering.
    return selected.sort_index(kind='mergesort', ascending=step > 0)


def stitch_events(event_list, window=60, suppress=False, min_duration=5):
    '''
    Turns events that are very close together temporally into single events, reducing fragmentation,
    then discards events that are still very short.
    Takes the following arguments:

    event_list:    (list); a list of (start, end) tuples, presumably sorted by start frame.
    window:        (int); consecutive events whose frame gap is at most this value are merged.
                          This is our threshold.
    suppress:      (bool); whether or not to show print statements.
    min_duration:  (int); stitched events lasting this many frames or fewer are discarded
                          (passed to remove_short_events; defaults to 5 as before).

    Returns a modified list of tuples, each tuple containing the start and end frame of each event.
    An empty event_list returns an empty list.
    '''

    # Calculate frame distance between consecutive detected events.
    # NOTE(review): abs() turns an overlap of N frames into a gap of N frames, so events that
    # overlap by more than `window` frames are NOT merged.
    frame_spacing = [abs(current[1] - following[0]) for current, following in zip(event_list, event_list[1:])]
    if not suppress:
        print('Frame spacing between detected "encounters" is:', frame_spacing)

    # Get a boolean array of which consecutive events are within window frames of each other.
    neighborhood = np.array(frame_spacing) <= window
    if not suppress:
        print('Low spacing:', neighborhood, sum(neighborhood))

    # Greedy merge: an event close to its predecessor extends the current stitched event;
    # otherwise it starts a new one. Works for empty and single-event lists too.
    stitched_events = []
    for index, event in enumerate(event_list):
        if index > 0 and neighborhood[index - 1]:
            stitched_events[-1] = (stitched_events[-1][0], event[1])
        else:
            stitched_events.append(event)

    if not suppress:
        print(stitched_events, len(stitched_events))

    # Take the now stitched events and clean them, i.e., remove any exceedingly short events, since they are most likely garbage.
    cleaned_events = remove_short_events(stitched_events, min_duration=min_duration, suppress=True)

    if not suppress:
        print(cleaned_events, len(cleaned_events))

    return cleaned_events


def test_normality(data, alpha=0.05):
    '''
    Test normality of data distribution, using Shapiro's test and D'Agostino's test. These test should not be
    significant to meet the assumption of normality of distributions. If the p-value is small, you reject the
    null hypothesis that both groups were sampled from populations with approximatel normal distributions.
    Takes the following arguments:

    data:   (dict); input data to test. Must be a dictionary.
    alpha:  (float); significance level.

    Returns a dictionary containing the corresponding p-value of each test for the given alpha level.
    '''

    assert isinstance(data, dict), 'Input data must be a dictionary.'

    p_values = {'shapiro': 0, 'agostino': 0}
    for n, values in enumerate(data.values()):
        W, shapiro_p = stats.shapiro(values)
        normally_distributed = shapiro_p > alpha
        report = ("Shapiro's Test: group {} {} normally distributed.")
        print(report.format(str(n+1), "IS" if normally_distributed else "IS NOT"))
        p_values['shapiro'] = shapiro_p

        k, agostino_p = stats.normaltest(values)
        normally_distributed = agostino_p > alpha
        report = ("D'Agostino's Test: group {} {} normally distributed.")
        print(report.format(str(n+1), "IS" if normally_distributed else "IS NOT"))
        p_values['agostino'] = agostino_p

    return p_values


def test_variance_homogeneity(data, alpha=0.05):
    '''
    Test variance homogeneity using Levene's test. This test should not be significant to meet the assumption
    of equality of variances. If the p-value is small, you reject the null hypothesis that both groups were
    sampled from populations with identical standard deviations (and thus identical variances).
    Takes the following arguments:

    data:   (dict); input data to test. Must be a dictionary.
    alpha:  (float); significance level.

    Returns a tuple containing the test statistic, and the corresponding p-value for the given alpha level.
    '''

    assert isinstance(data, dict), 'Input data must be a dictionary.'

    statistic, p = stats.levene(*data.values())
    report = ("Levene's Test:\n"
              "  p-value = {:5f}\n"
              "  All groups were sampled from populations with {} variances.")
    print(report.format(p, "IDENTICAL" if p > alpha else "NOT IDENTICAL"))

    return (statistic, p)
